import argparse

import yaml
from ultralytics import YOLO
from user_det_sg2yolo import sengo2yolo


if __name__ == '__main__':
    # Build a YOLO-format detection dataset from the sengo-format images under
    # <root_path>, then resume YOLOv8 training from a saved last.pt checkpoint.
    parser = argparse.ArgumentParser()
    parser.add_argument('--root_path', type=str, default=r'D:\data\231215安全带\trainV8Det_flball_blue', help='root_path')
    parser.add_argument('--weights', type=str, default=None,
                        help='checkpoint to resume from (default: <root_path>/models/train6/weights/last.pt)')
    # No `type=` so the default stays the int 1 used originally; values given on
    # the command line arrive as strings (e.g. "0", "0,1", "cpu").
    parser.add_argument('--device', default=1,
                        help='device passed to ultralytics train() (default: 1)')
    par = parser.parse_args()
    root_path = par.root_path

    # data: randomly crop the sengo-format data, convert it into a YOLOv5-format
    # dataset under <root_path>/format_data, and write data.yaml there.
    # NOTE(review): the crops are random and this runs before every resume, so the
    # training data is regenerated mid-run — confirm that is intended.
    s2y = sengo2yolo(glob_str=f'{root_path}/_add_imgs',
                     yolov5_dir=f'{root_path}/format_data')
    s2y.run()
    data_yaml_path = f'{root_path}/format_data/data.yaml'

    # train
    project = f'{root_path}/models'
    weights = par.weights or f'{root_path}/models/train6/weights/last.pt'
    # Build the model directly from the checkpoint. ultralytics resumes from the
    # checkpoint the YOLO object was created from; YOLO('yolov8s.yaml').load(weights)
    # only copies weights into a fresh model, leaving nothing for resume=True to use.
    model = YOLO(weights)
    # Per the ultralytics docs, a resumed run restores its training state and
    # arguments from the checkpoint; the values below mainly matter for a fresh run.
    results = model.train(data=data_yaml_path, epochs=100,
                          imgsz=640, device=par.device, resume=True,
                          batch=-1, workers=0, project=project)